#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <time.h>
#include <mex.h>

#include "MersenneTwister.h"

using namespace std;

///////////////////////////////////////////////////////////////////////////////
// Constants
///////////////////////////////////////////////////////////////////////////////
// Number of simulation reps treated as burn-in for each firm; statistics
// are recorded starting at rep == R_BURN.
#define R_BURN 10000

#define nN 64     // capacity of a firm's product arrays (max product lines per firm)
#define nNpow 6   // log2(nN); not referenced by the code below

// Firm types: index into xs, qmins, qbins and the transition table.
#define N_FTYPES 2
#define FTYPE_LO 0
#define FTYPE_HI 1

// Quality-draw distributions, addressed as qdists[type*N_DTYPES + dtype].
#define N_DTYPES 3
#define DTYPE_GAIN 0   // used for products gained via STATE_GAIN
#define DTYPE_FREE 1   // used for products gained via STATE_FREE
#define DTYPE_ENTR 2   // used for an entrant's first product

// Transition outcomes, drawn once per rep by dsamp1 over a row of nS
// cumulative probabilities; only outcomes 0..5 get positive probability
// (row entries 5..7 are all 1.0). The numeric order must match the order
// in which the rates are subtracted when the table is built.
#define nS 8             // row width of the transition table
#define nSpow 3          // log2(nS); not referenced by the code below
#define STATE_NOTHING 0  // no event this rep
#define STATE_LOSE    1  // lose a randomly chosen product (rate n*tau)
#define STATE_FREE    2  // gain a product drawn from DTYPE_FREE (rate n*rho)
#define STATE_GAIN    3  // gain a product drawn from DTYPE_GAIN (rate n*x)
#define STATE_AGE     4  // firm type is set to 0 (rate nu, zero for type 0)
#define STATE_EXIT    5  // product count set to 0; an entrant replaces the firm (rate psi)

///////////////////////////////////////////////////////////////////////////////
// Structs
///////////////////////////////////////////////////////////////////////////////

// Bundle of output-array pointers, one per tracked per-firm statistic.
// Not referenced by the simulation code in this file.
struct info_out
{
  int*    age_out;
  int*    type_out;
  int*    nprod_out;
  double* empl_out;
  int*    exited_out;
  int*    norig_out;
  double* empl_orig_out;
  int*    last_state_out;
};

///////////////////////////////////////////////////////////////////////////////
// Sampling helpers (host-side inline searches over cumulative tables)
///////////////////////////////////////////////////////////////////////////////

// Linear search over a cumulative table: returns the first index i in
// [0, len) with r < pbeg[i], or len when no entry exceeds r (including
// len <= 0). Meant for short tables such as one transition-table row.
inline int dsamp1(double* pbeg, int len, double r)
{
  int idx = 0;
  while (idx < len && !(r < pbeg[idx])) {
    ++idx;
  }
  return idx;
}

// Binary search over a cumulative table of exactly 2^pow2 entries.
// Always returns an index in [0, 2^pow2 - 1]. Assuming the table is
// ascending (a CDF -- TODO confirm for all callers) and r differs from the
// probed entries, the result is the number of entries below r, capped at
// 2^pow2 - 1. Requires pow2 >= 1: the initial shift is undefined otherwise.
inline int dsamp2(double* pbeg, int pow2, double r)
{
  int step = 1 << (pow2 - 1);   // 2^(pow2-1): first probe sits mid-table
  int idx = step - 1;

  for (int level = pow2 - 1; level > 0; --level) {
    step >>= 1;                 // jump shrinks by half at each level
    if (r < pbeg[idx]) {
      idx -= step;
    } else {
      idx += step;
    }
  }

  if (r > pbeg[idx]) {
    ++idx;                      // final one-slot adjustment at the leaf
  }
  return idx;
}

///////////////////////////////////////////////////////////////////////////////
// MEX code
///////////////////////////////////////////////////////////////////////////////
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
  // read in arguments
  if (nrhs != 17) {
    printf("Too few input arguments.\n");
    return;
  }

  if (nlhs != 13) {
    printf("Too few output arguments.\n");
    return;
  }

  // Handle matlab data
  // firmsim(nBpow,qbins_v{tid},qdists_v{tid},xt,qmint,taut,rhot,xoutt,epst,g,nu,psi,pqual,int32(R_PER_T),seed_v{tid});
 
  int nBpow         = mxGetScalar(prhs[0]);
  double* m_qbins   = (double*)mxGetData(prhs[1]); // N_FTYPES*nB
  double* m_qdists  = (double*)mxGetData(prhs[2]); // N_DTYPES*N_FTYPES*nB
  double* xs        = (double*)mxGetData(prhs[3]); // N_FTYPES
  double* qmins     = (double*)mxGetData(prhs[4]); // N_FTYPES
  double tau        = mxGetScalar(prhs[5]);
  double rho        = mxGetScalar(prhs[6]);
  double xout       = mxGetScalar(prhs[7]);
  double eps        = mxGetScalar(prhs[8]);
  double g          = mxGetScalar(prhs[9]);
  double nu         = mxGetScalar(prhs[10]);
  double psi        = mxGetScalar(prhs[11]);
  double alpha      = mxGetScalar(prhs[12]);
  int R_PER_T       = mxGetScalar(prhs[13]);
  unsigned int SEED = (unsigned int)mxGetScalar(prhs[14]);
  int T_PERIODS     = mxGetScalar(prhs[15]);
  int N_THREAD      = mxGetScalar(prhs[16]);

  int nB = 1<<nBpow;    
  int R_SIM = T_PERIODS*R_PER_T;
  int R_TOT = R_BURN+R_SIM;

  int qbN = mxGetN(prhs[1]);
  int qbM = mxGetM(prhs[1]);
  int qbL = max(qbN,qbM);
  if (qbL != nB) {
    printf("nBpow wrong.\n");
    return;
  }

  // initialize RNG
  MTRand* mt = new MTRand(SEED);

  // input values
  double delt  = 1.0/R_PER_T;   			// discretization
  double qdec  = 1.0/(1.0+delt*g);
  double pqual = alpha;
  int    nF    = 131072/N_THREAD; 

  // dists
  double* qdists[6] = {m_qdists,m_qdists+nB,m_qdists+2*nB,m_qdists+3*nB,m_qdists+4*nB,m_qdists+5*nB};
  double* qbins[2]  = {m_qbins,m_qbins+nB};

  // transition table
  int nel_tt = N_FTYPES*(nN+1)*nS;
  size_t ttsize = sizeof(double)*nel_tt;
  double* tvecs = (double*)malloc(ttsize);  

  double* svec;
  double ssum;
  double nut;
  double xt;
    
  for (int t = 0; t < N_FTYPES; t++) {
    nut = (t == 0) ? 0.0 : nu;  
    xt = xs[t];

    for (int n = 0; n < nN+1; n++) {  
      svec = tvecs + t*(nN+1)*nS + n*nS; 
      ssum = 1.0;
      svec[7] = ssum;
      svec[6] = ssum;
      svec[5] = ssum;
      ssum -= psi*delt;
      svec[4] = ssum;
      ssum -= nut*delt;    // -h: filling tvecs.
      svec[3] = ssum;
      ssum -= (n*xt)*delt;
      svec[2] = ssum;
      ssum -= (n*rho)*delt;
      svec[1] = ssum;
      ssum -= (n*tau)*delt;
      svec[0] = ssum;
    }
      
    if (ssum <= 0.0) printf("delt too large.");
  }  

  // output arrays
  mxArray* m_age 		    = mxCreateNumericMatrix(nF,T_PERIODS+1,mxINT32_CLASS,mxREAL);
  mxArray* m_type 		  = mxCreateNumericMatrix(nF,T_PERIODS+1,mxINT32_CLASS,mxREAL);
  mxArray* m_nprod 		  = mxCreateNumericMatrix(nF,T_PERIODS+1,mxINT32_CLASS,mxREAL);
  mxArray* m_empl 		  = mxCreateDoubleMatrix (nF,T_PERIODS+1,mxREAL);
  mxArray* m_exited 	  = mxCreateNumericMatrix(nF,T_PERIODS+1,mxINT32_CLASS,mxREAL);
  mxArray* m_norig 		  = mxCreateNumericMatrix(nF,T_PERIODS+1,mxINT32_CLASS,mxREAL);
  mxArray* m_empl_orig 	= mxCreateDoubleMatrix (nF,T_PERIODS+1,mxREAL);
  mxArray* m_last_state = mxCreateNumericMatrix(nF,T_PERIODS+1,mxINT32_CLASS,mxREAL);
  mxArray* m_ngain_rnd 	= mxCreateNumericMatrix(nF,T_PERIODS+1,mxINT32_CLASS,mxREAL);
  mxArray* m_ngain_res 	= mxCreateNumericMatrix(nF,T_PERIODS+1,mxINT32_CLASS,mxREAL);

  mxArray* m_qualEnd 		= mxCreateDoubleMatrix(nF,nN,mxREAL);
  mxArray* m_nlose 	    = mxCreateNumericMatrix(nF,T_PERIODS+1,mxINT32_CLASS,mxREAL);
  mxArray* m_qeps       = mxCreateDoubleMatrix(nF,T_PERIODS+1,mxREAL);
	
  int* h_age 			     = (int*)mxGetData(m_age);
  int* h_type 			   = (int*)mxGetData(m_type);
  int* h_nprod 			   = (int*)mxGetData(m_nprod);
  double* h_empl 		   = (double*)mxGetData(m_empl);
  int* h_exited 		   = (int*)mxGetData(m_exited);
  int* h_norig 			   = (int*)mxGetData(m_norig);
  double* h_empl_orig  = (double*)mxGetData(m_empl_orig);
  int* h_last_state 	 = (int*)mxGetData(m_last_state);
  int* h_ngain_rnd 		 = (int*)mxGetData(m_ngain_rnd);
  int* h_ngain_res 		 = (int*)mxGetData(m_ngain_res);
  
  
  int* h_nlose 		     = (int*)mxGetData(m_nlose);
  double* h_qualEnd 	 = (double*)mxGetData(m_qualEnd);
  double* h_qeps       = (double*)mxGetData(m_qeps);


  // output assignment
  plhs[0]  = m_age;
  plhs[1]  = m_type;
  plhs[2]  = m_nprod;
  plhs[3]  = m_empl;
  plhs[4]  = m_exited;
  plhs[5]  = m_norig;
  plhs[6]  = m_empl_orig;
  plhs[7]  = m_last_state;
  plhs[8]  = m_ngain_rnd;
  plhs[9]  = m_ngain_res;
  plhs[10] = m_nlose;
  plhs[11] = m_qualEnd;
  plhs[12] = m_qeps;


  //////////////////////
  // SIMULATION START //
  //////////////////////

  // local vars
  int s;
  int qind;
  int tech;
  int bin;
  double qval;
  double qvald;
  double qeps1;
  double r;
  double qmin;
  double* tvec;

  // tracking stats
  int prime_time = 0;    
  int r_sub;
  int period;
  int norig;
  int ngain_rnd = 0;
  int ngain_res = 0;
  int nlose=0;
  double emp;
  double emp_orig;
  int last_state = -1;
  int entry_ind;
  double qeps;

  // initial firm state
  int exited = 0;
  int age = 0;
  int type = 0;
  int n = 0;
  double quals[nN]; 
  int orig[nN];
  double qualDummy = 1;
  
  int rep;
  for (int f = 0; f < nF; f++) {  

    for (int rep = 0; rep <= R_BURN+R_SIM; rep++) { 

      // handle exits and first rep
      if (n == 0) {
        // store cause of exit
        if ((prime_time == 1) && (exited == 0)) last_state = s; 

        // init firm chars
        exited 	= 1;  
        age 	  = 0;     
        type 	  = (mt->rand() < pqual) ? FTYPE_HI : FTYPE_LO; 
      
        // first product
        while (qualDummy==1) {
        	qind = dsamp2(qdists[type*N_DTYPES+DTYPE_ENTR],nBpow,mt->rand()); 
        	qval = qbins[type][qind];
        	if (qval>=qmins[type]) {
        		qualDummy = 0;
        	}
        }
        quals[0] = qval;
        qualDummy = 1;

        n = 1;
      }

      // aggregate statistics
      if (rep >= R_BURN) {   
        if (rep == R_BURN) {
          prime_time = 1;
          period = 0;
          r_sub = 0;

          norig = 0;
          ngain_rnd = 0;
          ngain_res = 0;
          nlose     = 0;
          emp_orig = 0.0;
          exited = 0;

          for (qind = 0; qind < n; qind++) {
            orig[qind] = 1;  
          }                  
        }

        r_sub++; 

        if ((r_sub == R_PER_T) || (rep == R_BURN)) {
          emp = 0.0;                       
          emp_orig = 0.0;
          norig = 0;
          qeps = 0.0;
          for (qind = 0; qind < n; qind++) {
            qval = quals[qind];
            qeps1 = pow(qval,eps-1.0);
            emp += qeps1;  
            qeps += pow(qval,eps);
            if (orig[qind] == 1) {
              emp_orig += qeps1;  
              norig += 1;         
            }
          }

          h_age[period*nF+f]        = age;
          h_type[period*nF+f]       = type;
          h_nprod[period*nF+f]      = n;
          h_empl[period*nF+f]       = emp;
          h_exited[period*nF+f]     = exited;
          h_norig[period*nF+f]      = norig;
          h_empl_orig[period*nF+f]  = emp_orig;
          h_last_state[period*nF+f] = last_state;
          h_ngain_rnd[period*nF+f]  = ngain_rnd;
          h_ngain_res[period*nF+f]  = ngain_res;     
          h_nlose[period*nF+f]      = nlose;     
          h_qeps[period*nF+f]       = qeps;
          
          if (period==0){
          	for (qind = 0; qind < n; qind++) {
           	qval = quals[qind];
            	h_qualEnd[nF*qind + f] = quals[qind];
          	}
      	  }
          
          r_sub = 0; 
          period++;
        }
      }

      // the sampler
      tvec = tvecs + type*(nN+1)*nS + n*nS;  
      s = dsamp1(tvec,nS,mt->rand());
      r = mt->rand();
      switch (s) {
        case STATE_EXIT:
          n = 0;
          break;
        case STATE_AGE:
          type = 0;
          break;
        case STATE_GAIN:
          if (n < nN) {
            qind = dsamp2(qdists[type*N_DTYPES+DTYPE_GAIN],nBpow,r);   
            qval = qbins[type][qind];
            quals[n] = qval;   
            orig[n] = 0;       
            n++;
            ngain_rnd++;
          }
          break;
        case STATE_FREE: 
          if (n < nN) {
            qind = dsamp2(qdists[type*N_DTYPES+DTYPE_FREE],nBpow,r);
            qval = qbins[type][qind];
            quals[n] = qval;
            orig[n] = 0;
            n++;
            ngain_res++;
          }
          break;
        case STATE_LOSE:
          if (n > 0) {
            qind = floor(r*n);  
            qval = quals[qind]; 
            quals[qind] = quals[n-1];
            orig[qind]  = orig[n-1];      
            n--;
            nlose++;
          }
          break;
      }

      // decrement by growth
      qmin = qmins[type];
      for (qind = n-1; qind >= 0; qind--) {
        qval = quals[qind];
        qvald = qdec*qval;
        if (qvald < qmin) {
          quals[qind] = quals[n-1];
          orig[qind] = orig[n-1];
          n--;
          nlose++;
        } else {
          quals[qind] = qvald;
        }
      }

      // increment age
      age++; 
    }
  }

  // Free device memory
  free(tvecs);
}

